We read in the embedding file, in which protein sequences identified as outliers by manual inspection of the MDS plots are flagged (column `ManualOutlier`).
# Load the per-sequence embedding table and count how many sequences carry
# the ManualOutlier flag.
embeds <- readRDS("data/tmp/embeds_with_mds.rds")
print(table(embeds[["ManualOutlier"]]))
FALSE TRUE
647389 3367
# Size of the embedding table (rows x columns).
print(dim(embeds))
[1] 650756 969
# Embedding-dimension columns: every column whose name contains "embedding".
embed_cols <- grep("embedding", names(embeds), value = TRUE)
# Keep sequences not flagged as manual outliers; `%in% FALSE` also leaves out
# NA flags, just as data.table drops NA rows from a logical `i`.
clean_embeds <- embeds[ManualOutlier %in% FALSE]
head(clean_embeds[, 1:5])
#write_parquet(clean_embeds, "data/clean_embeds.parquet")
#clean_seq <- clean_embeds[,c("ID", "Taxonomy", "Gene", "AA_seq")]
#write_parquet(clean_seq, "data/clean_AA_seqs.parquet")
#head(clean_seq)
Next, we read in the data with the phenotypic values.
# Phenotype/metadata table, converted to a data.table in place.
processed_path <- "data/processed_data.parquet"
df <- read_parquet(processed_path)
setDT(df)
#pheno <- clean_data[,c("ID", "pheno_Topt_site_p50")]
#write_parquet(pheno, "data/pheno_topt_clean.parquet")
Next, I analyze the correlation between the embedding dimensions and the phenotype (bio8, `pheno_wc2.1_2.5m_bio_8_p50`), separately for each gene.
# For each gene, correlate every embedding dimension with bio8
# (pheno_wc2.1_2.5m_bio_8_p50) across that gene's sequences.
# all_cors[[gene]] is a numeric vector named by embedding column.
genes <- unique(clean_embeds$Gene)
# ID + phenotype; loop-invariant, so built once.
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]
all_cors <- list()
for (gene in genes) {
  # Exact match: grep("rpl2", Gene) would also select rpl20 and rpl23.
  gene_data <- clean_embeds[Gene == gene]
  gene_data <- merge(gene_data, pheno, by = "ID")
  # "na.or.complete" gives NA instead of an error when no complete pairs exist.
  cors <- vapply(embed_cols, function(col) {
    cor(gene_data[[col]], gene_data$pheno_wc2.1_2.5m_bio_8_p50,
        use = "na.or.complete")
  }, numeric(1))
  # hist() errors when there is nothing finite to plot.
  if (any(is.finite(cors))) hist(cors, main = gene)
  all_cors[[gene]] <- cors
}
# Two panels side by side: per-dimension correlation histograms for psbN and rbcL.
par(mfrow = c(1L, 2L))
hist(all_cors$psbN, main = "Cor of psbN embeds with bio8")
hist(all_cors$rbcL, main = "rbcL")
# psaC only: MDS1 distribution and the MDS1/MDS2 scatter.
gene_data <- clean_embeds[Gene == "psaC"]
par(mfrow = c(1, 2))
hist(gene_data$MDS1)
plot(gene_data$MDS1, gene_data$MDS2, main = "psaC MDS results")
#hist(gene_data$MDS2)
par(mfrow = c(1, 1))
# One psaC MDS1 histogram per taxonomic order, on a shared x-range.
for (ord in unique(df$Order[!is.na(df$Order)])) {
  # Exact match: grep() would treat `ord` as a regex and match substrings.
  order_ids <- df[Order == ord, ID]
  order_subset <- gene_data[ID %in% order_ids]
  # hist() errors when there are no finite values to plot.
  if (!any(is.finite(order_subset$MDS1))) next
  hist(order_subset$MDS1, main = ord, xlim = c(-.1, .1))
}
# psaC MDS1 by taxonomic order. Exact gene match (grep() matches substrings).
gene_data <- clean_embeds[Gene == "psaC"]
merged <- merge(gene_data, df[, c("ID", "Order")], by = "ID")
boxplot(MDS1 ~ Order, data = merged,
        main = "psaC MDS1 by Order",
        las = 2, outline = FALSE)
library(pheatmap)
# Genes x embedding-dimension matrix of correlations with bio8.
mat <- do.call("rbind", unname(all_cors))
rownames(mat) <- names(all_cors)
blue_white_red <- colorRampPalette(c("blue", "white", "red"))
pheatmap(mat, color = blue_white_red(100))
# Genes x embedding-dimension correlation matrix.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)
# Gene-gene similarity = Pearson correlation between the genes' correlation
# profiles; cor() works on columns, so transpose first.
profiles <- t(mat)
gene_sim <- cor(profiles, use = "pairwise.complete.obs")
# Clustered heatmap of the gene-gene similarity.
pheatmap(gene_sim,
         main = "Similarity of gene embeddings wrt bio8",
         color = colorRampPalette(c("blue", "white", "red"))(100))
hist(gene_sim)
# Each gene pair once: the strict upper triangle of the similarity matrix.
offdiag <- gene_sim[upper.tri(gene_sim)]
hist(offdiag,
     breaks = 30,
     main = "Stability of embedding correlations across genes",
     xlab = "Pairwise correlation",
     col = "skyblue")
# Vertical red line at the mean pairwise similarity.
mean_offdiag <- mean(offdiag, na.rm = TRUE)
abline(v = mean_offdiag, col = "red", lwd = 2)
# psaC CDS length in codons (nucleotides / 3). length() of the column is the
# row count (a single number); nchar() gives one length per sequence.
# NOTE(review): assumes a `psaC_CDS` column of nucleotide strings exists in
# clean_embeds; it may instead live in `df` — confirm.
psac_cds <- clean_embeds[["psaC_CDS"]]
if (is.null(psac_cds)) {
  warning("Column psaC_CDS not found in clean_embeds; skipping CDS length histogram.")
} else {
  hist(nchar(psac_cds) / 3, main = "psaC CDS length", xlab = "Codons")
}
# Genes x embedding-dimension correlation matrix.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)
# Apply a summary function to every column of `mat`, keeping column names.
col_summary <- function(f) {
  vapply(colnames(mat), function(nm) f(mat[, nm]), numeric(1))
}
# Per embedding dimension, across genes: signed mean, mean |cor|, and sd.
embed_stats <- data.frame(
  dim = colnames(mat),
  mean_cor = col_summary(function(x) mean(x, na.rm = TRUE)),
  mean_abs_cor = col_summary(function(x) mean(abs(x), na.rm = TRUE)),
  sd_cor = col_summary(function(x) sd(x, na.rm = TRUE))
)
plot(embed_stats$mean_cor, embed_stats$sd_cor,
     xlab = "mean of correlation across genes",
     ylab = "sd of correlation across genes")
abline(a = 0, b = 1)
# Strongest mean |cor| first; ties broken by the smallest sd.
ranking <- order(-embed_stats$mean_abs_cor, embed_stats$sd_cor)
best <- embed_stats[ranking[1], ]
best
library(stats)
# Per gene: correlate each embedding dimension with bio8, keep the (up to)
# 5 dimensions with the largest |cor|, and regress bio8 on them.
# results[[gene]] = list(cors, top_dims, model = summary of the lm fit).
top_k <- 5
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]
results <- list()
for (gene in genes) {
  gene_data <- clean_embeds[Gene == gene]
  merged <- merge(gene_data, pheno, by = "ID")
  cors <- vapply(embed_cols, function(col) {
    cor(merged[[col]], merged$pheno_wc2.1_2.5m_bio_8_p50, use = "na.or.complete")
  }, numeric(1))
  ranked <- names(sort(abs(cors), decreasing = TRUE))  # sort() drops NA
  if (length(ranked) == 0L) {
    warning(sprintf("Gene %s: no computable correlations; skipping.", gene))
    next
  }
  top_dims <- ranked[seq_len(min(top_k, length(ranked)))]
  formula_str <- paste("pheno_wc2.1_2.5m_bio_8_p50 ~", paste(top_dims, collapse = " + "))
  fit <- lm(as.formula(formula_str), data = merged)
  # lm() drops incomplete rows, so merged$pheno... can be longer than the
  # fitted values; plot against the response lm actually used.
  plot(fitted(fit), model.response(model.frame(fit)),
       xlab = "Fitted bio8", ylab = "Observed bio8", main = gene)
  results[[gene]] <- list(
    cors = cors,
    top_dims = top_dims,
    model = summary(fit)
  )
}
results$psaC$top_dims
[1] "embedding_59" "embedding_118" "embedding_459" "embedding_37" "embedding_220"
print(results[["psaC"]][["model"]])
Call:
lm(formula = as.formula(formula_str), data = merged)
Residuals:
Min 1Q Median 3Q Max
-25.935 -4.952 1.659 4.864 16.783
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 10.073 4.149 2.428 0.015204 *
embedding_59 751.717 117.098 6.420 1.42e-10 ***
embedding_118 -155.406 153.138 -1.015 0.310218
embedding_459 -633.788 188.642 -3.360 0.000783 ***
embedding_37 -252.371 131.748 -1.916 0.055447 .
embedding_220 245.807 290.302 0.847 0.397165
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 6.416 on 10761 degrees of freedom
Multiple R-squared: 0.02567, Adjusted R-squared: 0.02521
F-statistic: 56.69 on 5 and 10761 DF, p-value: < 2.2e-16
# Per-gene variable selection and a combined design matrix: for each gene
# keep the n_top embedding dimensions with the largest |cor| with the
# phenotype, rename them "<gene>__<dim>", and inner-join all genes by ID.
library(data.table)
library(stats)
# ensure data.tables
setDT(clean_embeds)
setDT(df)
# phenotype column name (adjust if you want a different pheno)
pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
# number of top dims per gene
n_top <- 5
# ID + phenotype, shared by every gene below (loop-invariant)
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]
# container for per-gene selected data.tables (ID + renamed top embeds)
sel_list <- vector("list", length(genes))
names(sel_list) <- genes
for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  # Inner join; keep only rows with an observed phenotype so the size check
  # below counts usable rows.
  gdt <- merge(gdt, pheno_dt, by = "ID")[!is.na(pheno)]
  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }
  # "na.or.complete" returns NA (rather than stopping the loop) when a
  # dimension has no complete pairs with the phenotype.
  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "na.or.complete")
  }, numeric(1))
  available <- names(cors)[!is.na(cors)]
  k <- min(n_top, length(available))
  if (k == 0L) {
    warning(sprintf("Gene %s has no computable correlations; skipping.", gene))
    next
  }
  # seq_len(k), not 1:k, which would select an NA name when k = 0.
  top_dims <- names(sort(abs(cors[available]), decreasing = TRUE))[seq_len(k)]
  # ID + selected dims only; rename dims to gene__dim
  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}
# remove genes we skipped (slot still NULL)
sel_list <- Filter(Negate(is.null), sel_list)
if (length(sel_list) == 0) stop("No genes with selected dims found.")
# merge all per-gene tables by ID, keeping only IDs present in ALL (intersection)
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = FALSE), sel_list)
# bring phenotype back
combined <- merge(combined, pheno_dt, by = "ID", all.x = TRUE, all.y = FALSE)
# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
Samples (IDs) in combined matrix: 5907
n_predictors <- ncol(combined) - 2 # everything except ID and pheno
cat("Number of predictors:", n_predictors, "(ID and pheno excluded)\n")
Number of predictors: 305 (ID and pheno excluded)
# Fit pheno on all gene__dim predictors. lm_fit was used here without being
# defined in this chunk, so build the formula and fit explicitly; try() turns
# a failed fit into a warning instead of stopping the document.
lm_formula <- as.formula(paste("pheno ~", paste(setdiff(colnames(combined), c("ID", "pheno")), collapse = " + ")))
lm_fit <- try(lm(lm_formula, data = combined), silent = TRUE)
if (inherits(lm_fit, "try-error")) {
  warning("lm failed (probably too many predictors). See glmnet example below.")
} else {
  print(summary(lm_fit))
}
Call:
lm(formula = lm_formula, data = combined)
Residuals:
Min 1Q Median 3Q Max
-24.1608 -2.9069 0.4772 3.1734 22.0632
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 76.8254 55.6181 1.381 0.167241
atpA__embedding_876 -208.8354 242.6664 -0.861 0.389503
atpA__embedding_94 -575.7553 160.9851 -3.576 0.000351 ***
atpA__embedding_741 -179.2769 104.7317 -1.712 0.086994 .
atpA__embedding_201 777.5276 123.6092 6.290 3.41e-10 ***
atpA__embedding_232 -262.7959 141.4287 -1.858 0.063200 .
atpB__embedding_224 231.1105 145.2126 1.592 0.111546
atpB__embedding_514 910.9106 238.5938 3.818 0.000136 ***
atpB__embedding_634 -93.0585 138.4092 -0.672 0.501393
atpB__embedding_695 269.1380 145.7439 1.847 0.064851 .
atpB__embedding_582 531.0567 213.9233 2.482 0.013077 *
atpE__embedding_781 -107.8754 137.9126 -0.782 0.434130
atpE__embedding_254 35.7627 102.8618 0.348 0.728096
atpE__embedding_17 65.8292 111.7574 0.589 0.555860
atpE__embedding_633 187.9588 114.2747 1.645 0.100068
atpE__embedding_548 142.9291 110.8520 1.289 0.197323
atpF__embedding_739 62.8667 128.4662 0.489 0.624603
atpF__embedding_721 -130.7040 68.3097 -1.913 0.055748 .
atpF__embedding_522 -46.2784 100.6542 -0.460 0.645694
atpF__embedding_720 -95.5329 53.2440 -1.794 0.072828 .
atpF__embedding_234 -96.1766 79.0348 -1.217 0.223698
atpH__embedding_248 -392.6980 664.4051 -0.591 0.554509
atpH__embedding_583 1072.0269 906.1453 1.183 0.236834
atpH__embedding_46 -726.1298 1140.4567 -0.637 0.524346
atpH__embedding_310 -1483.3559 789.0042 -1.880 0.060155 .
atpH__embedding_256 449.7932 761.7268 0.590 0.554885
atpI__embedding_282 -198.1053 148.1331 -1.337 0.181164
atpI__embedding_349 -86.2666 135.7952 -0.635 0.525279
atpI__embedding_823 -548.2127 190.5590 -2.877 0.004032 **
atpI__embedding_599 153.3643 173.1258 0.886 0.375734
atpI__embedding_324 -320.3872 169.5294 -1.890 0.058828 .
ccsA__embedding_736 -300.9649 87.9722 -3.421 0.000628 ***
ccsA__embedding_563 -175.8039 60.1767 -2.921 0.003498 **
ccsA__embedding_196 -142.1579 61.1807 -2.324 0.020184 *
ccsA__embedding_710 165.9675 68.8158 2.412 0.015907 *
ccsA__embedding_744 54.4999 77.4317 0.704 0.481558
cemA__embedding_764 98.3458 103.4912 0.950 0.342010
cemA__embedding_402 101.9166 62.3800 1.634 0.102356
cemA__embedding_480 38.0450 51.3430 0.741 0.458727
cemA__embedding_836 156.2334 84.3986 1.851 0.064202 .
cemA__embedding_730 -178.8383 69.4146 -2.576 0.010010 *
matK__embedding_867 186.8447 72.8111 2.566 0.010309 *
matK__embedding_869 80.5346 99.9816 0.805 0.420569
matK__embedding_655 -270.9279 88.1532 -3.073 0.002127 **
matK__embedding_272 61.9943 85.9940 0.721 0.470993
matK__embedding_637 -96.6917 111.5608 -0.867 0.386134
ndhA__embedding_204 139.1071 91.9982 1.512 0.130574
ndhA__embedding_328 272.3231 113.4779 2.400 0.016437 *
ndhA__embedding_924 -189.1086 117.8001 -1.605 0.108476
ndhA__embedding_584 108.4594 94.1441 1.152 0.249347
ndhA__embedding_637 339.3575 97.9665 3.464 0.000536 ***
ndhB__embedding_16 63.4961 251.6460 0.252 0.800801
ndhB__embedding_829 -330.2832 273.7211 -1.207 0.227621
ndhB__embedding_19 -286.1228 281.0232 -1.018 0.308652
ndhB__embedding_701 13.1022 247.8648 0.053 0.957845
ndhB__embedding_404 -36.2425 224.9070 -0.161 0.871986
ndhC__embedding_172 -5.2171 112.4138 -0.046 0.962985
ndhC__embedding_495 -381.2100 149.7819 -2.545 0.010951 *
ndhC__embedding_619 80.1857 149.1111 0.538 0.590765
ndhC__embedding_656 207.2433 158.5764 1.307 0.191301
ndhC__embedding_939 100.8045 203.1951 0.496 0.619845
ndhD__embedding_610 -5.3162 96.6444 -0.055 0.956134
ndhD__embedding_154 65.7973 102.9415 0.639 0.522737
ndhD__embedding_16 -22.1424 76.1572 -0.291 0.771257
ndhD__embedding_126 -333.0086 107.4735 -3.099 0.001955 **
ndhD__embedding_493 129.2113 89.0514 1.451 0.146843
ndhE__embedding_394 -235.0752 130.0469 -1.808 0.070719 .
ndhE__embedding_291 215.9190 87.1772 2.477 0.013287 *
ndhE__embedding_36 250.5818 107.1751 2.338 0.019419 *
ndhE__embedding_327 194.3733 111.9194 1.737 0.082491 .
ndhE__embedding_792 207.9030 117.2270 1.774 0.076199 .
ndhG__embedding_936 -21.4389 69.0997 -0.310 0.756374
ndhG__embedding_616 -235.8352 102.3305 -2.305 0.021223 *
ndhG__embedding_689 150.3810 100.8533 1.491 0.135995
ndhG__embedding_651 138.8690 58.8018 2.362 0.018228 *
ndhG__embedding_600 -154.4796 92.1588 -1.676 0.093748 .
ndhH__embedding_448 27.2937 244.6118 0.112 0.911161
ndhH__embedding_958 154.8204 218.6563 0.708 0.478942
ndhH__embedding_844 269.3165 131.0652 2.055 0.039942 *
ndhH__embedding_282 34.9133 162.1551 0.215 0.829535
ndhH__embedding_180 -154.7166 117.3752 -1.318 0.187512
ndhI__embedding_356 -1.5720 105.6532 -0.015 0.988129
ndhI__embedding_247 549.0817 124.7357 4.402 1.09e-05 ***
ndhI__embedding_729 5.9835 160.1021 0.037 0.970189
ndhI__embedding_854 232.8629 113.8125 2.046 0.040801 *
ndhI__embedding_802 -277.7352 107.0273 -2.595 0.009484 **
ndhJ__embedding_910 65.5178 120.5850 0.543 0.586923
ndhJ__embedding_478 -370.7750 122.5513 -3.025 0.002494 **
ndhJ__embedding_719 822.2006 169.5346 4.850 1.27e-06 ***
ndhJ__embedding_220 60.0519 124.3602 0.483 0.629195
ndhJ__embedding_728 -263.0135 91.7493 -2.867 0.004164 **
ndhK__embedding_939 -84.2404 108.0653 -0.780 0.435699
ndhK__embedding_517 -9.8507 106.7471 -0.092 0.926478
ndhK__embedding_506 126.2942 140.7308 0.897 0.369535
ndhK__embedding_651 51.5243 82.6169 0.624 0.532880
ndhK__embedding_323 91.6913 90.6985 1.011 0.312086
petA__embedding_612 334.5739 153.5199 2.179 0.029347 *
petA__embedding_729 -602.0623 184.7791 -3.258 0.001128 **
petA__embedding_897 -453.2711 148.9370 -3.043 0.002350 **
petA__embedding_324 -652.7140 156.1078 -4.181 2.94e-05 ***
petA__embedding_319 386.2192 139.2024 2.775 0.005547 **
petG__embedding_425 84.5845 262.7228 0.322 0.747500
petG__embedding_374 -59.6645 315.6244 -0.189 0.850071
petG__embedding_521 434.0284 318.9684 1.361 0.173655
petG__embedding_915 615.3282 292.1598 2.106 0.035237 *
petG__embedding_951 623.7942 217.1642 2.872 0.004088 **
petN__embedding_958 -58.0280 363.3937 -0.160 0.873136
petN__embedding_489 -110.4462 313.8402 -0.352 0.724913
petN__embedding_834 60.7427 443.6018 0.137 0.891090
petN__embedding_36 -738.7642 526.0309 -1.404 0.160252
petN__embedding_365 82.3972 408.3399 0.202 0.840091
psaA__embedding_484 1108.8366 388.8539 2.852 0.004367 **
psaA__embedding_771 364.4961 282.5170 1.290 0.197044
psaA__embedding_829 237.4691 404.4587 0.587 0.557141
psaA__embedding_690 -112.0831 297.0699 -0.377 0.705969
psaA__embedding_434 219.3752 308.5731 0.711 0.477154
psaB__embedding_765 -281.7155 277.6143 -1.015 0.310258
psaB__embedding_239 1715.7434 436.0704 3.935 8.44e-05 ***
psaB__embedding_44 -750.2019 342.0719 -2.193 0.028340 *
psaB__embedding_366 21.2841 256.0890 0.083 0.933765
psaB__embedding_707 156.7416 338.5362 0.463 0.643384
psaC__embedding_59 -363.0662 304.2073 -1.193 0.232731
psaC__embedding_118 -199.1120 360.1609 -0.553 0.580394
psaC__embedding_459 -387.7371 319.8318 -1.212 0.225443
psaC__embedding_37 22.7648 314.5797 0.072 0.942313
psaC__embedding_220 500.4841 534.1604 0.937 0.348822
psaJ__embedding_168 89.3835 129.4254 0.691 0.489834
psaJ__embedding_5 348.1794 130.4995 2.668 0.007651 **
psaJ__embedding_931 163.9249 120.7750 1.357 0.174748
psaJ__embedding_281 -194.2151 129.4487 -1.500 0.133587
psaJ__embedding_189 182.5217 146.1124 1.249 0.211649
psbA__embedding_514 1285.7125 533.1846 2.411 0.015924 *
psbA__embedding_553 -541.1976 352.4451 -1.536 0.124705
psbA__embedding_950 198.7766 342.1091 0.581 0.561242
psbA__embedding_414 278.7585 374.8898 0.744 0.457165
psbA__embedding_786 194.5895 375.2246 0.519 0.604064
psbB__embedding_269 -91.2722 302.1339 -0.302 0.762593
psbB__embedding_221 1010.9740 296.1797 3.413 0.000646 ***
psbB__embedding_854 226.2997 212.3829 1.066 0.286683
psbB__embedding_655 1445.8331 273.5239 5.286 1.30e-07 ***
psbB__embedding_217 330.6320 247.2457 1.337 0.181192
psbC__embedding_438 1593.7468 334.7349 4.761 1.97e-06 ***
psbC__embedding_30 828.2348 233.5209 3.547 0.000393 ***
psbC__embedding_394 -1446.6585 408.0907 -3.545 0.000396 ***
psbC__embedding_836 640.6248 331.8060 1.931 0.053568 .
psbC__embedding_251 -606.9502 431.3952 -1.407 0.159499
psbD__embedding_351 -497.3078 298.4608 -1.666 0.095721 .
psbD__embedding_814 1073.5785 468.3372 2.292 0.021924 *
psbD__embedding_574 -361.4385 132.2990 -2.732 0.006315 **
psbD__embedding_702 -375.5442 474.4349 -0.792 0.428650
psbD__embedding_352 261.6981 366.9131 0.713 0.475725
psbE__embedding_261 -60.4732 195.3587 -0.310 0.756915
psbE__embedding_480 131.1661 90.3575 1.452 0.146659
psbE__embedding_143 114.7043 274.4804 0.418 0.676039
psbE__embedding_724 -64.5953 260.7966 -0.248 0.804388
psbE__embedding_152 -367.9184 257.4361 -1.429 0.153013
psbF__embedding_163 753.9860 266.0071 2.834 0.004607 **
psbF__embedding_595 317.8768 254.1353 1.251 0.211053
psbF__embedding_406 -990.0347 396.2723 -2.498 0.012505 *
psbF__embedding_848 77.8062 266.0174 0.292 0.769926
psbF__embedding_335 375.3744 195.6757 1.918 0.055117 .
psbH__embedding_712 -12.6540 92.8964 -0.136 0.891656
psbH__embedding_478 -179.3157 114.4757 -1.566 0.117310
psbH__embedding_656 -19.7671 103.5579 -0.191 0.848627
psbH__embedding_305 77.4610 81.8522 0.946 0.344010
psbH__embedding_481 279.7592 88.2166 3.171 0.001526 **
psbI__embedding_490 -1729.5498 414.8573 -4.169 3.11e-05 ***
psbI__embedding_728 -29.2831 319.8111 -0.092 0.927048
psbI__embedding_1 309.0379 448.6852 0.689 0.491001
psbI__embedding_890 40.1530 370.3839 0.108 0.913675
psbI__embedding_151 -1164.5657 352.0678 -3.308 0.000946 ***
psbJ__embedding_72 274.8862 237.2479 1.159 0.246650
psbJ__embedding_286 8.8962 196.0862 0.045 0.963815
psbJ__embedding_5 366.8672 225.3730 1.628 0.103619
psbJ__embedding_24 41.2185 27.4781 1.500 0.133658
psbJ__embedding_924 53.8107 166.4295 0.323 0.746462
psbK__embedding_53 190.3357 91.6326 2.077 0.037832 *
psbK__embedding_723 227.7159 77.4277 2.941 0.003285 **
psbK__embedding_255 -127.0001 91.6073 -1.386 0.165694
psbK__embedding_892 -80.0394 72.4595 -1.105 0.269377
psbK__embedding_222 -38.7123 77.1417 -0.502 0.615804
psbM__embedding_53 -113.9512 145.9467 -0.781 0.434969
psbM__embedding_570 642.2320 196.3048 3.272 0.001076 **
psbM__embedding_475 127.5162 118.5945 1.075 0.282319
psbM__embedding_646 -44.4199 196.6175 -0.226 0.821272
psbM__embedding_959 -243.3920 124.5491 -1.954 0.050729 .
psbN__embedding_527 440.9233 211.7174 2.083 0.037333 *
psbN__embedding_815 -536.4711 219.9086 -2.440 0.014738 *
psbN__embedding_156 21.7603 255.8526 0.085 0.932225
psbN__embedding_490 314.1987 233.1926 1.347 0.177913
psbN__embedding_505 17.6603 295.4896 0.060 0.952344
psbT__embedding_604 30.2452 171.4804 0.176 0.860004
psbT__embedding_711 521.1176 150.9538 3.452 0.000560 ***
psbT__embedding_941 252.3240 114.5114 2.203 0.027601 *
psbT__embedding_106 384.7161 147.7706 2.603 0.009253 **
psbT__embedding_803 43.8725 103.3712 0.424 0.671278
psbZ__embedding_338 203.0167 186.6933 1.087 0.276892
psbZ__embedding_227 -68.5603 196.9585 -0.348 0.727782
psbZ__embedding_861 29.6209 145.3293 0.204 0.838502
psbZ__embedding_90 -144.6589 128.2703 -1.128 0.259467
[ reached getOption("max.print") -- omitted 106 rows ]
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 4.896 on 5601 degrees of freedom
Multiple R-squared: 0.4499, Adjusted R-squared: 0.4199
F-statistic: 15.02 on 305 and 5601 DF, p-value: < 2.2e-16
# Per-gene variable selection with n_top = 1: for every gene keep the
# embedding dimension(s) with the largest |cor| with the phenotype, renamed
# "<gene>__<dim>". Per-gene tables are collected in sel_list.
library(data.table)
library(stats)
setDT(clean_embeds)
setDT(df)
pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
n_top <- 1
# ID + phenotype, joined onto every gene below (loop-invariant)
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]
sel_list <- vector("list", length(genes))
names(sel_list) <- genes
for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  # Inner join; keep only rows with an observed phenotype so the size check
  # counts usable rows.
  gdt <- merge(gdt, pheno_dt, by = "ID")[!is.na(pheno)]
  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }
  # "na.or.complete" yields NA instead of an error when a dimension has no
  # complete pairs ("complete.obs" errors, possibly what stopped the knit).
  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "na.or.complete")
  }, numeric(1))
  available <- names(cors)[!is.na(cors)]
  k <- min(n_top, length(available))
  if (k == 0L) {
    warning(sprintf("Gene %s has no computable correlations; skipping.", gene))
    next
  }
  # seq_len(k), not 1:k, which would select an NA name when k = 0.
  top_dims <- names(sort(abs(cors[available]), decreasing = TRUE))[seq_len(k)]
  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}
Quitting from lines 256-307 [unnamed-chunk-16] (makePairsPostMDSFilter.Rmd)
# Drop genes that were skipped (their slot is still NULL).
skipped <- sapply(sel_list, is.null)
sel_list <- sel_list[!skipped]
# Full outer join across genes: an ID lacking a gene gets NA in that gene's
# columns, filled by median imputation below.
combined <- Reduce(function(left, right) merge(left, right, by = "ID", all = TRUE), sel_list)
# Attach the phenotype (left join: IDs without a phenotype keep NA).
combined <- merge(combined, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = TRUE, all.y = FALSE)
# Median-impute each predictor column in place (ID and pheno excluded).
pred_cols <- setdiff(names(combined), c("ID", "pheno"))
for (pc in pred_cols) {
  fill <- median(combined[[pc]], na.rm = TRUE)
  combined[which(is.na(combined[[pc]])), (pc) := fill]
}
# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
Samples (IDs) in combined matrix: 10857
n_pred <- length(pred_cols)
cat("Number of predictors:", n_pred, "\n")
Number of predictors: 61
# Ordinary least squares of pheno on every predictor column (ID excluded).
# try() downgrades a failed fit to a warning.
predictor_terms <- paste(setdiff(colnames(combined), c("ID", "pheno")), collapse = " + ")
lm_formula <- as.formula(paste("pheno ~", predictor_terms))
lm_fit <- try(lm(lm_formula, data = combined), silent = TRUE)
if (!inherits(lm_fit, "try-error")) {
  print(summary(lm_fit))
} else {
  warning("lm failed (probably too many predictors). See glmnet example below.")
}
Call:
lm(formula = lm_formula, data = combined)
Residuals:
Min 1Q Median 3Q Max
-28.0942 -3.2983 0.8184 3.8462 19.7686
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 76.205 5.541 13.752 < 2e-16 ***
atpA__embedding_876 393.909 85.681 4.597 4.33e-06 ***
atpB__embedding_224 20.269 57.698 0.351 0.725374
atpE__embedding_781 -69.591 48.947 -1.422 0.155127
atpF__embedding_739 -103.899 44.859 -2.316 0.020570 *
atpH__embedding_248 214.038 136.790 1.565 0.117679
atpI__embedding_282 -13.242 47.525 -0.279 0.780532
ccsA__embedding_736 -115.425 31.021 -3.721 0.000200 ***
cemA__embedding_764 -5.409 40.665 -0.133 0.894194
matK__embedding_867 -33.408 30.345 -1.101 0.270935
ndhA__embedding_204 7.551 38.344 0.197 0.843889
ndhB__embedding_16 52.587 65.182 0.807 0.419815
ndhC__embedding_172 -67.894 33.199 -2.045 0.040869 *
ndhD__embedding_610 -129.337 33.226 -3.893 9.98e-05 ***
ndhE__embedding_394 -161.721 49.409 -3.273 0.001067 **
ndhG__embedding_936 189.680 26.602 7.130 1.07e-12 ***
ndhH__embedding_448 -327.157 101.589 -3.220 0.001284 **
ndhI__embedding_356 -118.316 39.630 -2.986 0.002837 **
ndhJ__embedding_910 -17.921 54.810 -0.327 0.743706
ndhK__embedding_939 49.281 38.795 1.270 0.204003
petA__embedding_612 -192.007 47.767 -4.020 5.87e-05 ***
petG__embedding_425 76.336 30.410 2.510 0.012079 *
petN__embedding_958 -364.799 115.487 -3.159 0.001589 **
psaA__embedding_484 617.128 138.603 4.452 8.57e-06 ***
psaB__embedding_765 44.646 121.998 0.366 0.714402
psaC__embedding_59 -128.172 99.698 -1.286 0.198608
psaJ__embedding_168 391.925 51.369 7.630 2.55e-14 ***
psbA__embedding_514 1464.388 200.833 7.292 3.28e-13 ***
psbB__embedding_269 723.484 96.089 7.529 5.51e-14 ***
psbC__embedding_438 767.096 114.227 6.716 1.97e-11 ***
psbD__embedding_351 -532.484 105.238 -5.060 4.27e-07 ***
psbE__embedding_261 259.766 85.849 3.026 0.002485 **
psbF__embedding_163 891.369 115.667 7.706 1.41e-14 ***
psbH__embedding_712 -145.498 37.374 -3.893 9.96e-05 ***
psbI__embedding_490 -994.834 143.682 -6.924 4.64e-12 ***
psbJ__embedding_72 988.822 67.731 14.599 < 2e-16 ***
psbK__embedding_53 127.885 40.554 3.153 0.001618 **
psbM__embedding_53 212.570 51.348 4.140 3.50e-05 ***
psbN__embedding_527 207.500 45.044 4.607 4.14e-06 ***
psbT__embedding_604 -87.724 42.380 -2.070 0.038485 *
psbZ__embedding_338 336.252 58.052 5.792 7.14e-09 ***
rbcL__embedding_425 -1219.722 78.300 -15.578 < 2e-16 ***
rpl14__embedding_327 19.500 28.170 0.692 0.488828
rpl16__embedding_355 -38.872 30.023 -1.295 0.195434
rpl2__embedding_422 15.997 34.288 0.467 0.640832
rpl20__embedding_422 -32.124 42.109 -0.763 0.445555
rpl23__embedding_434 -12.479 31.976 -0.390 0.696355
rpl33__embedding_446 54.198 27.397 1.978 0.047925 *
rpl36__embedding_661 37.785 34.438 1.097 0.272593
rpoA__embedding_201 38.182 28.417 1.344 0.179095
rpoB__embedding_774 -327.040 55.123 -5.933 3.07e-09 ***
rpoC1__embedding_238 -141.284 66.400 -2.128 0.033380 *
rps11__embedding_254 -12.219 32.930 -0.371 0.710594
rps14__embedding_437 157.310 28.094 5.599 2.20e-08 ***
rps18__embedding_255 81.100 24.592 3.298 0.000977 ***
rps19__embedding_469 70.034 19.765 3.543 0.000397 ***
rps3__embedding_349 -62.348 27.691 -2.252 0.024371 *
rps4__embedding_363 9.651 38.608 0.250 0.802616
rps7__embedding_838 -183.027 78.300 -2.338 0.019431 *
rps8__embedding_78 79.731 30.892 2.581 0.009865 **
ycf3__embedding_29 -110.264 61.399 -1.796 0.072544 .
ycf4__embedding_373 1.939 57.604 0.034 0.973143
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 5.434 on 10795 degrees of freedom
Multiple R-squared: 0.3022, Adjusted R-squared: 0.2983
F-statistic: 76.65 on 61 and 10795 DF, p-value: < 2.2e-16
# Observed vs fitted for the full-data fit. lm() drops rows with a missing
# phenotype, so combined$pheno can be longer than the fitted values; plot
# against the response lm actually used. The title reports the actual n_top
# and number of rows rather than hard-coded values.
if (!inherits(lm_fit, "try-error")) {
  plot(fitted(lm_fit), model.response(model.frame(lm_fit)),
       xlab = "Fitted", ylab = "Observed",
       main = sprintf("pheno ~ top%d embs per gene, n=%d", n_top, nobs(lm_fit)))
}
library(ggplot2)
set.seed(123) # reproducibility
# predictors
pred_cols <- setdiff(names(combined), c("ID", "pheno"))
## ---- 1. Random 10% holdout ----
idx_test <- sample(seq_len(nrow(combined)), size = ceiling(0.1 * nrow(combined)))
train1 <- combined[-idx_test]
test1 <- combined[idx_test]
fit1 <- lm(pheno ~ ., data = train1[, c("pheno", pred_cols), with = FALSE])
pred1 <- predict(fit1, newdata = test1)
# plot
ggplot(data.frame(obs = test1$pheno, pred = pred1), aes(x = pred, y = obs)) +
  geom_point(alpha = 0.5) +
  geom_abline(color = "red", linetype = "dashed") +
  labs(title = "Random 10% Holdout",
       x = "Predicted phenotype",
       y = "Observed phenotype") +
  theme_minimal()
## ---- 2. Poales holdout ----
# Train on all IDs outside Poales (matched in the Taxonomy string), test on Poales.
poales_ids <- df[grepl("Poales", Taxonomy), ID]
train2 <- combined[!ID %in% poales_ids]
test2 <- combined[ID %in% poales_ids]
fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
pred2 <- predict(fit2, newdata = test2)
# plot
ggplot(data.frame(obs = test2$pheno, pred = pred2), aes(x = pred, y = obs)) +
  geom_point(alpha = 0.5, color = "darkgreen") +
  geom_abline(color = "red", linetype = "dashed") +
  labs(title = "Poales Holdout",
       x = "Predicted phenotype",
       y = "Observed phenotype") +
  theme_minimal()
rmse <- function(a,b) sqrt(mean((a - b)^2))
# Metrics use test rows with both a finite prediction and an observed
# phenotype (NAs would make them NA). cor() defaults to Pearson, so request
# Spearman explicitly to match the label.
ok2 <- is.finite(pred2) & !is.na(test2$pheno)
spearman2 <- cor(pred2[ok2], test2$pheno[ok2], method = "spearman")
rmse2 <- rmse(pred2[ok2], test2$pheno[ok2])
plot(pred2, test2$pheno,
     main = sprintf("bio8 ~ top%d embs/gene, holdout Poales", n_top),
     xlab = "Predicted",
     ylab = "Observed")
abline(a = 0, b = 1, col = "coral")
# legend() places the metrics inside the plot region regardless of data range.
legend("bottomright", bty = "n", text.col = "red",
       legend = c(paste0("spearman=", round(spearman2, 3)),
                  paste0("RMSE=", round(rmse2, 3))))
# Leave-one-order-out: for each taxonomic order, fit on all other IDs and
# evaluate on that order.
rmse <- function(a, b) sqrt(mean((a - b)^2))
for (ord in unique(df$Order[!is.na(df$Order)])) {
  # IDs of this order (exact match on the Order column, not a substring/regex
  # match in the Taxonomy string)
  order_ids <- df[Order == ord, ID]
  train2 <- combined[!ID %in% order_ids]
  test2 <- combined[ID %in% order_ids]
  if (nrow(test2) == 0L) next
  fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
  pred2 <- predict(fit2, newdata = test2)
  # Complete test rows only; plot() and cor() fail on empty input.
  ok <- is.finite(pred2) & !is.na(test2$pheno)
  if (sum(ok) < 3L) next
  prd <- pred2[ok]
  obs <- test2$pheno[ok]
  # ggplot objects are not auto-printed inside a for loop.
  print(
    ggplot(data.frame(obs = obs, pred = prd), aes(x = pred, y = obs)) +
      geom_point(alpha = 0.5, color = "darkgreen") +
      geom_abline(color = "red", linetype = "dashed") +
      labs(title = paste0(ord, " Holdout"),
           x = "Predicted phenotype",
           y = "Observed phenotype") +
      theme_minimal()
  )
  plot(prd, obs, main = sprintf("bio8 ~ top%d embs/gene, holdout %s", n_top, ord),
       xlab = "Predicted",
       ylab = "Observed")
  abline(a = 0, b = 1, col = "coral")
  # cor() defaults to Pearson; request Spearman to match the label.
  text(quantile(prd, 0.75), quantile(obs, 0.25),
       paste0("spearman=", round(cor(prd, obs, method = "spearman"), 3)), col = "red")
  text(quantile(prd, 0.75), quantile(obs, 0.25) - 2,
       paste0("RMSE=", round(rmse(prd, obs), 3)), col = "red")
}
# Sweep n_top (embedding dimensions kept per gene) over 1..100 and record,
# for each value, out-of-sample Spearman rho and RMSE of a linear model on
# (1) a random 10% ID holdout and (2) a Poaceae holdout. Results go into
# `results` (columns n_top, split, spearman, rmse).
# NOTE(review): at large n_top the predictor count approaches the training
# size; lm becomes slow and rank-deficient — consider penalized regression.
library(data.table)
library(ggplot2)
setDT(clean_embeds)
setDT(df)
pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
genes <- unique(clean_embeds$Gene)
rmse <- function(a, b) sqrt(mean((a - b)^2))
n_top_grid <- 1:100
max_top <- max(n_top_grid)
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]

# Rank one gene's embedding dimensions by |cor| with the phenotype. This does
# not depend on n_top, so it runs once per gene rather than once per gene and
# n_top. Returns NULL for unusable genes, else list(ranked = top dims by
# decreasing |cor|, gdt = ID + those dims for rows with a phenotype).
rank_gene_dims <- function(gene) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  gdt <- merge(gdt, pheno_dt, by = "ID")[!is.na(pheno)]
  if (nrow(gdt) < 5) return(NULL)
  # "na.or.complete": NA instead of an error when no complete pairs exist.
  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "na.or.complete")
  }, numeric(1))
  ranked <- names(sort(abs(cors), decreasing = TRUE)) # sort() drops NA
  ranked <- ranked[seq_len(min(max_top, length(ranked)))]
  if (length(ranked) == 0L) return(NULL)
  list(ranked = ranked, gdt = gdt[, c("ID", ranked), with = FALSE])
}

# Design matrix for one n_top: each gene's top dims renamed "<gene>__<dim>",
# full outer join on ID, phenotype attached, missing predictors median-imputed.
build_combined <- function(n_top) {
  sel_list <- lapply(names(gene_rank), function(gene) {
    gr <- gene_rank[[gene]]
    top_dims <- gr$ranked[seq_len(min(n_top, length(gr$ranked)))]
    sel <- gr$gdt[, c("ID", top_dims), with = FALSE] # a copy: setnames is safe
    setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
    sel
  })
  out <- Reduce(function(a, b) merge(a, b, by = "ID", all = TRUE), sel_list)
  out <- merge(out, pheno_dt, by = "ID", all.x = TRUE, all.y = FALSE)
  for (col in setdiff(names(out), c("ID", "pheno"))) {
    med <- median(out[[col]], na.rm = TRUE)
    out[which(is.na(out[[col]])), (col) := med]
  }
  out
}

# Fit pheno ~ all predictors on `train`; return Spearman rho and RMSE on the
# `test` rows with a finite prediction and an observed phenotype.
eval_split <- function(train, test, pred_cols) {
  fit <- lm(pheno ~ ., data = train[, c("pheno", pred_cols), with = FALSE])
  pred <- predict(fit, newdata = test)
  ok <- is.finite(pred) & !is.na(test$pheno)
  if (sum(ok) < 2L) return(list(spearman = NA_real_, rmse = NA_real_))
  list(spearman = cor(pred[ok], test$pheno[ok], method = "spearman"),
       rmse = rmse(pred[ok], test$pheno[ok]))
}

gene_rank <- setNames(lapply(genes, rank_gene_dims), genes)
gene_rank <- Filter(Negate(is.null), gene_rank)
if (length(gene_rank) == 0L) stop("No genes with computable correlations.")

# Draw the random 10% test IDs once: the combined ID set does not depend on
# n_top, so every n_top is scored on the same holdout and differences are
# not split noise.
set.seed(123)
all_ids <- unique(unlist(lapply(gene_rank, function(gr) gr$gdt$ID), use.names = FALSE))
test_ids <- all_ids[sample.int(length(all_ids), size = ceiling(0.1 * length(all_ids)))]
poaceae_ids <- df[grepl("Poaceae", Taxonomy), ID]

results_list <- vector("list", length(n_top_grid))
for (i in seq_along(n_top_grid)) {
  n_top <- n_top_grid[i]
  combined <- build_combined(n_top)
  pred_cols <- setdiff(names(combined), c("ID", "pheno"))
  ## ---- 1. Random 10% holdout ----
  m1 <- eval_split(combined[!ID %in% test_ids], combined[ID %in% test_ids], pred_cols)
  rows <- list(data.table(n_top = n_top, split = "random10",
                          spearman = m1$spearman, rmse = m1$rmse))
  ## ---- 2. Poaceae holdout ----
  train2 <- combined[!ID %in% poaceae_ids]
  test2 <- combined[ID %in% poaceae_ids]
  if (nrow(test2) > 0 && nrow(train2) > 10) {
    m2 <- eval_split(train2, test2, pred_cols)
    rows[[2]] <- data.table(n_top = n_top, split = "poaceae",
                            spearman = m2$spearman, rmse = m2$rmse)
  }
  results_list[[i]] <- rbindlist(rows)
}
results <- rbindlist(results_list)
Quitting from lines 420-514 [unnamed-chunk-21] (makePairsPostMDSFilter.Rmd)